import os
import yaml
import numpy as np
import argparse


def build_link_map(lattice):
    """
    Map every lattice link ``((x, y), mu)`` to its position in ``lattice``.

    Each entry is normalised with ``tuple()`` so it can be used as a dict key.
    If the same link occurs more than once, the last position wins.
    """
    link_map = {}
    for position, entry in enumerate(lattice):
        link_map[tuple(entry)] = position
    return link_map


def compute_wilson_loops(lattice, U, sizes, bc='periodic'):
    """
    Compute square L x L Wilson loops W(C) on a two-dimensional lattice.

    The loop starting at site (x, y) is the ordered product of link variables
    along the path right -> up -> left -> down.  Links traversed backwards are
    replaced by their inverse: the complex conjugate for U(1) scalars, or the
    conjugate transpose for matrices.  For matrix-valued links the trace of
    the product is returned.

    Parameters
    ----------
    lattice : sequence of ((x, y), mu)
        Link labels; entry ``i`` labels ``U[i]``.  The distinct sites must
        form a square ``N x N`` grid.
    U : sequence
        Link variables aligned with ``lattice``: complex scalars (U(1)) or
        square matrices (SU(N)).
    sizes : iterable of int
        Loop edge lengths to measure, e.g. ``[1, 2, 3]``.
    bc : {'periodic', 'open'}
        ``'periodic'`` wraps coordinates modulo N, so every site starts a
        loop.  ``'open'`` only measures loops that fit entirely inside the
        grid, i.e. starting sites with ``x + L <= N - 1`` and
        ``y + L <= N - 1``.

    Returns
    -------
    dict
        ``size -> list`` of loop values, ordered by starting site (x outer,
        y inner).  With open boundaries the list is empty when ``L >= N``.

    Raises
    ------
    ValueError
        If ``bc`` is not recognised or the sites do not form a square grid.
    """
    if bc not in ('periodic', 'open'):
        raise ValueError(f"bc must be 'periodic' or 'open', got {bc!r}")

    n_sites = len({pos for pos, mu in lattice})
    # Round rather than truncate so float error in sqrt cannot shrink N.
    n = int(round(np.sqrt(n_sites)))
    if n * n != n_sites:
        raise ValueError(
            f"lattice has {n_sites} distinct sites, which is not a square grid"
        )

    link_map = build_link_map(lattice)
    results = {L: [] for L in sizes}
    if n_sites == 0:
        # Nothing to measure; also avoids indexing U[0] on empty input.
        return results

    # Determine if U elements are scalar (U(1)) or matrices (SU(N)).
    sample_U = U[0]
    is_scalar = np.isscalar(sample_U) or getattr(sample_U, 'shape', ()) == ()
    if is_scalar:
        identity = 1 + 0j
    else:
        identity = np.eye(sample_U.shape[0], dtype=complex)

    def link(site, mu, inverse=False):
        """Return the link variable at (site, mu), optionally inverted."""
        value = U[link_map[(site, mu)]]
        if inverse:
            # Group elements are unitary, so the inverse is the conjugate
            # (conjugate transpose for matrices).
            value = np.conjugate(value) if is_scalar else np.conjugate(value).T
        return value

    def multiply(W, value):
        """Right-multiply the running product W by one link variable."""
        return W * value if is_scalar else W @ value

    for L in sizes:
        # Open boundaries: the far corner (x + L, y + L) must stay on the grid,
        # so the modulo below never actually wraps in that case.
        n_start = n if bc == 'periodic' else max(n - L, 0)
        for x in range(n_start):
            for y in range(n_start):
                # multiply() always builds a new object, so the shared
                # identity is never modified in place.
                W = identity
                # 1) Right steps (mu=0)
                for i in range(L):
                    W = multiply(W, link(((x + i) % n, y), 0))
                # 2) Up steps (mu=1)
                for j in range(L):
                    W = multiply(W, link(((x + L) % n, (y + j) % n), 1))
                # 3) Left steps (mu=0, traversed backwards)
                for i in range(L):
                    W = multiply(
                        W, link(((x + L - 1 - i) % n, (y + L) % n), 0, inverse=True)
                    )
                # 4) Down steps (mu=1, traversed backwards)
                for j in range(L):
                    W = multiply(
                        W, link((x, (y + L - 1 - j) % n), 1, inverse=True)
                    )
                # Take trace or scalar depending on representation
                results[L].append(W if is_scalar else np.trace(W))
    return results


def _resolve_dir(base_dir, path):
    """Return ``path`` if it is absolute, otherwise ``path`` joined onto ``base_dir``."""
    return path if os.path.isabs(path) else os.path.join(base_dir, path)


def main(config_path: str):
    """
    Measure Wilson loops for each configured gauge group and write CSV files.

    The YAML config may set ``data_dir``, ``results_dir``, ``loop_sizes``,
    ``boundary_conditions`` and ``gauge_groups``; all keys are optional.
    Relative directories are resolved against the config file's directory.
    Reads ``lattice.npy`` and ``Umu_<G>.npy`` from the data directory and
    writes ``wilson_<G>.csv`` (columns ``size,average,real,imag``) to the
    results directory, creating it if needed.

    Raises
    ------
    ValueError
        If the YAML document is not a mapping (an empty file is allowed).
    """
    # Load config
    with open(config_path, encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # An empty YAML file loads as None; treat it as "use every default".
    if cfg is None:
        cfg = {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f'{config_path}: top-level YAML must be a mapping, '
            f'got {type(cfg).__name__}'
        )

    # Resolve data_dir and results_dir relative to the directory of the config file
    base_dir = os.path.dirname(os.path.abspath(config_path))
    data_dir = _resolve_dir(base_dir, cfg.get('data_dir', 'data'))
    results_dir = _resolve_dir(base_dir, cfg.get('results_dir', 'results'))
    sizes = cfg.get('loop_sizes', [1, 2, 3])
    bc = cfg.get('boundary_conditions', 'periodic')
    gauge_groups = cfg.get('gauge_groups', ['U1'])

    # NOTE(review): allow_pickle=True lets a crafted .npy file execute
    # arbitrary code on load; only point data_dir at trusted data.
    lattice = np.load(os.path.join(data_dir, 'lattice.npy'), allow_pickle=True)

    os.makedirs(results_dir, exist_ok=True)

    for G in gauge_groups:
        # Load U_mu (same trust caveat as above)
        Umu = np.load(os.path.join(data_dir, f'Umu_{G}.npy'), allow_pickle=True)
        # Compute loops
        results = compute_wilson_loops(lattice, Umu, sizes, bc)
        # Save per-group results to CSV
        out_path = os.path.join(results_dir, f'wilson_{G}.csv')
        with open(out_path, 'w', encoding='utf-8') as fout:
            fout.write('size,average,real,imag\n')
            for L, vals in results.items():
                # np.mean([]) warns and yields nan; produce the same nan quietly.
                avg = np.mean(vals) if len(vals) else np.nan
                fout.write(f"{L},{avg},{np.real(avg)},{np.imag(avg)}\n")
        print(f'Saved Wilson loop results for {G} to {out_path}')


if __name__ == '__main__':
    # Command-line entry point: only the config path is configurable here.
    cli = argparse.ArgumentParser(description='Measure Wilson loops')
    cli.add_argument(
        '--config',
        default='config.yaml',
        help='Path to config file',
    )
    main(cli.parse_args().config)